We use DestVI and a single-cell reference dataset to infer the cell types present in each spot.
Here are the libraries we need.
import os
import re
import shutil
from os import makedirs, remove, rmdir
from os.path import exists, join

import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import session_info
from scipy.sparse import csr_matrix, issparse
from scvi.model import CondSCVI, DestVI
Global seed set to 0
Some useful functions to save AnnData objects and models.
def save_anndata(anndata_object, num, suffix):
    """Write an AnnData object to output/h5ad/anndata.<num>.<suffix>.h5.

    The expression matrix is converted to CSR sparse format first so every
    saved file has a consistent on-disk representation.
    """
    if not issparse(anndata_object.X):
        anndata_object.X = csr_matrix(anndata_object.X)
    target = join("output/h5ad", f"anndata.{num}.{suffix}.h5")
    anndata_object.write_h5ad(target)
def save_model(model_object, directory, suffix):
    """Save a model under <directory>/model_<suffix>, replacing any previous save.

    The previous implementation deleted only model.pt and then called rmdir,
    which raises OSError whenever the model directory contains any other file
    (a saved model directory typically holds more than one file).  Removing
    the whole directory with shutil.rmtree is robust to its contents.
    """
    model_dir = join(directory, f"model_{suffix}")
    if exists(model_dir):
        shutil.rmtree(model_dir)
    model_object.save(model_dir)
def has_zeros(name, adata):
    """Print and return how many genes and cells have a total count of zero.

    Works for both dense arrays and scipy sparse matrices (whose .sum()
    returns an np.matrix — normalized here with asarray/ravel).

    Returns:
        (n_zero_genes, n_zero_cells); the previous version only printed and
        returned None, so the check could not be used programmatically.
    """
    gene_totals = np.asarray(adata.X.sum(axis=0)).ravel()
    cell_totals = np.asarray(adata.X.sum(axis=1)).ravel()
    genes = int(np.sum(gene_totals == 0))
    cells = int(np.sum(cell_totals == 0))
    print(f"{name}: {genes} zeros for genes and {cells} zeros for cells")
    return genes, cells
We create the output directory for this notebook. All outputs will be saved there.
# Create the notebook's output directory.  exist_ok=True makes re-runs
# harmless while still surfacing real failures; the old try/except OSError
# with `pass` silently swallowed every error, including permission problems.
makedirs("output/h5ad", exist_ok=True)
Configuration of scanpy.
# Scanpy configuration: write figures under "output" and use small,
# frameless 4x4-inch panels.
sc.settings.figdir = "output"
sc.set_figure_params(figsize=(4, 4), frameon=False)
We load the single-cell reference dataset.
# Load the 10x reference matrix and attach the per-cell type annotations,
# aligning the annotation index with the cell barcodes.
sc_adata = sc.read_10x_mtx("reference")
types = pd.read_csv(
    join("reference", "types.tsv.gz"),
    header=None,
    compression="gzip",
).rename(columns={0: "Type"})
types.index = sc_adata.obs_names
sc_adata.obs = types
sc_adata
AnnData object with n_obs × n_vars = 113507 × 20310
obs: 'Type'
var: 'gene_ids', 'feature_types'
We use the gene symbols as IDs and convert everything to upper case. So, we also need to check that there are no duplicates in the gene symbols.
# Use upper-cased gene symbols as IDs and keep only the first occurrence of
# each symbol so the index has no duplicates.
sc_adata.var_names = sc_adata.var_names.str.upper()
first_occurrence = ~sc_adata.var_names.duplicated()
sc_adata = sc_adata[:, first_occurrence]
We need to remove the cells and the genes with no counts.
# Drop genes with no counts, then cells with no counts, and save the result.
gene_has_counts = ~np.all(sc_adata.X.toarray() == 0, axis=0)
sc_adata = sc_adata[:, gene_has_counts]
cell_has_counts = ~np.all(sc_adata.X.toarray() == 0, axis=1)
sc_adata = sc_adata[cell_has_counts, :]
save_anndata(sc_adata, "00", "raw")
/usr/local/lib/python3.10/dist-packages/anndata/_core/anndata.py:1235: ImplicitModificationWarning: Trying to modify attribute `.obs` of view, initializing view as actual. df[key] = c /usr/local/lib/python3.10/dist-packages/anndata/_core/anndata.py:1235: ImplicitModificationWarning: Trying to modify attribute `.var` of view, initializing view as actual. df[key] = c
We remove the genes with too few counts.
# Keep only genes with at least 10 total counts, then snapshot the raw counts
# into a dedicated layer (the seurat_v3 HVG selection below reads this layer).
sc.pp.filter_genes(sc_adata, min_counts=10)
sc_adata.layers["counts"] = sc_adata.X.copy()
save_anndata(sc_adata, "01", "min_counts")
We select the most variable genes. This selection can create cells with no counts. So, we need to remove these cells.
# Keep the 2000 most variable genes (seurat_v3 flavor, computed on the raw
# counts layer, subsetting the object in place).
hvg_args = dict(
    n_top_genes=2000,
    subset=True,
    layer="counts",
    flavor="seurat_v3",
)
sc.pp.highly_variable_genes(sc_adata, **hvg_args)
sc_adata.layers["counts"] = sc_adata.X.copy()
# The gene subsetting can leave cells without any count; drop those cells and
# refresh the counts layer afterwards.
cells_with_counts = ~np.all(sc_adata.X.toarray() == 0, axis=1)
sc_adata = sc_adata[cells_with_counts, :]
sc_adata.layers["counts"] = sc_adata.X.copy()
save_anndata(sc_adata, "02", "var_genes")
We normalize and log-transform the counts.
# Normalize each cell to the same total count, then log1p-transform.
# NOTE(review): target_sum=10e4 equals 100 000, not the common 1e4 —
# confirm this value was intended.
sc.pp.normalize_total(sc_adata, target_sum=10e4)
save_anndata(sc_adata, "03", "normalized")
sc.pp.log1p(sc_adata)
# Keep the log-normalized matrix available via .raw before further subsetting.
sc_adata.raw = sc_adata
save_anndata(sc_adata, "04", "logtransformed")
We load the sample count and spatial data and create an AnnData object.
# Load the spot coordinates and the count matrix, align them by barcode, and
# drop genes with no counts.
spatial = (
    pd.read_csv("sample1.csv")
    .rename(columns={"PuckBarcode": "Barcode"})
    .set_index("SeqBarcode")
)
st_adata = sc.read_10x_mtx("sample1")
st_adata.obs = spatial.loc[st_adata.obs.index]
expressed = ~np.all(st_adata.X.toarray() == 0, axis=0)
st_adata = st_adata[:, expressed]
We need to apply the same preprocessing as with the scRNA-seq reference dataset. So, we convert the gene symbols to upper case. Then, we normalize, log-transform the counts and remove the duplicates.
# Mirror the reference preprocessing on the spatial data: upper-case gene
# symbols, normalize, log-transform, then drop every duplicated symbol
# (keep=False removes all copies) from both datasets.
st_adata.var_names = st_adata.var_names.str.upper()
sc.pp.normalize_total(st_adata, target_sum=10e4)
sc.pp.log1p(st_adata)
st_adata.raw = st_adata
sc_unique = ~sc_adata.var.index.duplicated(keep=False)
sc_adata = sc_adata[:, sc_unique]
st_unique = ~st_adata.var.index.duplicated(keep=False)
st_adata = st_adata[:, st_unique]
We select only the genes that are common to the sample and the reference. The number of genes can end up being quite low because we took only the most variable genes.
# Restrict both datasets to their shared genes; because the reference was
# reduced to its most variable genes, the overlap can be fairly small.
intersect = np.intersect1d(sc_adata.var_names, st_adata.var_names)
print("{0} genes in common over {1}".format(len(intersect), sc_adata.shape[1]))
adata = sc_adata[:, intersect].copy()
# Subsetting genes can leave cells/spots with zero counts; drop them and
# refresh the counts layers.
adata = adata[~np.all(adata.X.toarray() == 0, axis=1), :]
adata.layers["counts"] = adata.X.copy()
st_adata = st_adata[:, intersect].copy()
st_adata = st_adata[~np.all(st_adata.X.toarray() == 0, axis=1), :]
st_adata.layers["counts"] = st_adata.X.copy()
# Store the spot coordinates for the "spatial" embedding used when plotting.
st_adata.obsm["spatial"] = st_adata.obs[["x", "y"]].to_numpy()
# Sanity check: no all-zero genes or cells/spots should remain.
# (Fixed the misspelled "refeference" label in the printed output.)
has_zeros("reference", adata)
has_zeros("spatial", st_adata)
1841 genes in common over 2000 refeference: 0 zeros for genes and 0 zeros for cells spatial: 0 zeros for genes and 0 zeros for cells
We can now set up the single-cell model.
# Register the reference with CondSCVI, using the "Type" column as labels.
# Labels are cast to plain strings first (the column contains a NaN entry,
# which becomes the string "nan" category).
adata.obs["Type"] = adata.obs.Type.astype(str) # CondSCVI fails otherwise
CondSCVI.setup_anndata(adata, layer="counts", labels_key="Type")
sc_model = CondSCVI(adata, weight_obs=False)
sc_model.view_anndata_setup()
/usr/local/lib/python3.10/dist-packages/scvi/data/fields/_layer_field.py:78: UserWarning: adata.layers[counts] does not contain unnormalized count data. Are you sure this is what you want? warnings.warn(
Anndata setup with scvi-tools version 0.16.4.
Setup via `CondSCVI.setup_anndata` with arguments:
{'labels_key': 'Type', 'layer': 'counts'}
Summary Statistics ┏━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓ ┃ Summary Stat Key ┃ Value ┃ ┡━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩ │ n_cells │ 113507 │ │ n_vars │ 1841 │ │ n_labels │ 18 │ └──────────────────┴────────┘
Data Registry ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Registry Key ┃ scvi-tools Location ┃ ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ X │ adata.layers['counts'] │ │ labels │ adata.obs['_scvi_labels'] │ └──────────────┴───────────────────────────┘
labels State Registry ┏━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓ ┃ Source Location ┃ Categories ┃ scvi-tools Encoding ┃ ┡━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩ │ adata.obs['Type'] │ Astrocyte │ 0 │ │ │ CA1 │ 1 │ │ │ CA3 │ 2 │ │ │ Cajal_Retzius │ 3 │ │ │ Choroid │ 4 │ │ │ Denate │ 5 │ │ │ Endothelial_Stalk │ 6 │ │ │ Endothelial_Tip │ 7 │ │ │ Entorihinal │ 8 │ │ │ Ependymal │ 9 │ │ │ Interneuron │ 10 │ │ │ Microglia_Macrophages │ 11 │ │ │ Mural │ 12 │ │ │ Neurogenesis │ 13 │ │ │ Neuron.Slc17a6 │ 14 │ │ │ Oligodendrocyte │ 15 │ │ │ Polydendrocyte │ 16 │ │ │ nan │ 17 │ └───────────────────┴───────────────────────┴─────────────────────┘
We train the model here.
# Train the single-cell model; flip quick_run to True for a fast smoke test.
quick_run = False
if quick_run:
    sc_model.train(max_epochs=2)
    sc_model.history["elbo_train"].plot()
else:
    sc_model.train(max_epochs=150)
    # Skip the first epochs so the plot is not dominated by the initial drop.
    sc_model.history["elbo_train"].iloc[5:].plot()
plt.show()
save_model(sc_model, "output", "singlecell")
save_anndata(adata, "05", "trained")
/usr/local/lib/python3.10/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has
not been set for this class (ElboMetric). The property determines if `update` by
default needs access to the full metric state. If this is not the case, significant speedups can be
achieved and we recommend setting this to `False`.
We provide an checking function
`from torchmetrics.utilities import check_forward_no_full_state`
that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
default for now) or if `full_state_update=False` can be used safely.
warnings.warn(*args, **kwargs)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Epoch 150/150: 100%|██████████| 150/150 [16:24<00:00, 6.57s/it, loss=659, v_num=1]
We can now set up the spatial model.
# Register the spatial data and initialize the DestVI model from the trained
# single-cell model.
DestVI.setup_anndata(st_adata, layer="counts")
st_model = DestVI.from_rna_model(st_adata, sc_model)
st_model.view_anndata_setup()
/usr/local/lib/python3.10/dist-packages/scvi/data/fields/_layer_field.py:78: UserWarning: adata.layers[counts] does not contain unnormalized count data. Are you sure this is what you want? warnings.warn(
Anndata setup with scvi-tools version 0.16.4.
Setup via `DestVI.setup_anndata` with arguments:
{'layer': 'counts'}
Summary Statistics ┏━━━━━━━━━━━━━━━━━━┳━━━━━━━┓ ┃ Summary Stat Key ┃ Value ┃ ┡━━━━━━━━━━━━━━━━━━╇━━━━━━━┩ │ n_cells │ 35049 │ │ n_vars │ 1841 │ └──────────────────┴───────┘
Data Registry ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Registry Key ┃ scvi-tools Location ┃ ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩ │ X │ adata.layers['counts'] │ │ ind_x │ adata.obs['_indices'] │ └──────────────┴────────────────────────┘
We train the model here.
# Train the spatial model; flip quick_run to True for a fast smoke test.
quick_run = False
if quick_run:
    st_model.train(max_epochs=2)
    st_model.history["elbo_train"].plot()
else:
    st_model.train(max_epochs=1500)
    # Skip the first epochs so the plot is not dominated by the initial drop.
    st_model.history["elbo_train"].iloc[10:].plot()
plt.show()
st_adata.obsm["proportions"] = st_model.get_proportions()
save_model(st_model, "output", "spatial")
save_anndata(st_adata, "06", "proportions")
/usr/local/lib/python3.10/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has
not been set for this class (ElboMetric). The property determines if `update` by
default needs access to the full metric state. If this is not the case, significant speedups can be
achieved and we recommend setting this to `False`.
We provide an checking function
`from torchmetrics.utilities import check_forward_no_full_state`
that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
default for now) or if `full_state_update=False` can be used safely.
warnings.warn(*args, **kwargs)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Epoch 1500/1500: 100%|██████████| 1500/1500 [56:37<00:00, 2.27s/it, loss=4.34e+06, v_num=1]
We clip each cell type's proportions at the 99th percentile and, as a first approximation, assign each spot the cell type with the maximum proportion.
# Clip each cell type's per-spot proportions at its 99th percentile and store
# them in .obs, then label every spot with its highest-proportion cell type.
proportions = st_adata.obsm["proportions"]
for cell_type in proportions.columns:
    values = proportions[cell_type].values
    st_adata.obs[cell_type] = np.clip(values, 0, np.quantile(values, 0.99))
st_adata.obs["CellType"] = proportions.idxmax(axis=1)
sc.settings.figdir = "output"
sc.settings.file_format_figs = "png"
# Plot every spot at its spatial coordinates, colored by its assigned cell
# type; the figure is saved as output/spatial.cell_type_all.png.
sc.pl.embedding(
    st_adata,
    basis="spatial",
    color="CellType",
    s=3,
    save=f".cell_type_all",
    show=True
)
WARNING: saving figure to file output/spatial.cell_type_all.png
We plot the probability values for each cell type and for each spot.
sc.settings.figdir = "output"
sc.settings.file_format_figs = "png"
# One spatial panel per cell type, colored by its (clipped) proportion value;
# the grid is saved as output/spatial.cell_type_grid.png.
sc.pl.embedding(
    st_adata,
    basis="spatial",
    color=st_adata.obsm["proportions"].columns,
    cmap="Reds",
    s=3,
    save=f".cell_type_grid",
    show=True
)
WARNING: saving figure to file output/spatial.cell_type_grid.png
session_info.show()
----- anndata 0.8.0 matplotlib 3.5.2 numpy 1.21.5 pandas 1.3.5 scanpy 1.8.2 scipy 1.8.0 scvi 0.16.4 session_info 1.0.0 -----
OpenSSL 21.0.0 PIL 9.1.1 absl NA attr 21.2.0 backcall 0.2.0 bcrypt 3.2.0 beta_ufunc NA binom_ufunc NA boto3 1.23.8 botocore 1.26.8 bottleneck 1.3.2 brotli 1.0.9 certifi 2020.06.20 chardet 4.0.0 chex 0.1.3 colorama 0.4.5 cryptography 3.4.8 cycler 0.10.0 cython_runtime NA dateutil 2.8.2 debugpy 1.6.1 decorator 4.4.2 defusedxml 0.7.1 deprecate 0.3.1 docrep 0.3.2 entrypoints 0.4 etils 0.6.0 flatbuffers 2.0 flax 0.5.2 fsspec 2022.5.0 google NA h5py 3.7.0 hypergeom_ufunc NA idna 3.3 igraph 0.9.11 ipykernel 6.13.1 ipython_genutils 0.2.0 ipywidgets 7.7.1 jax 0.3.14 jaxlib 0.3.14 jedi 0.18.0 jmespath 1.0.0 joblib 0.17.0 jupyter_server 1.17.1 kiwisolver 1.3.2 leidenalg 0.8.10 llvmlite 0.38.1 louvain 0.7.1 lz4 4.0.0+dfsg matplotlib_inline NA mpl_toolkits NA msgpack 1.0.3 multipledispatch 0.6.0 natsort 8.1.0 nbinom_ufunc NA numba 0.55.2 numexpr 2.8.3 numpyro 0.10.0 opt_einsum v3.3.0 optax 0.1.2 packaging 21.3 parso 0.8.1 pexpect 4.8.0 pickleshare 0.7.5 pkg_resources NA portalocker 2.2.1 prompt_toolkit 3.0.30 psutil 5.9.0 ptyprocess 0.7.0 pydev_ipython NA pydevconsole NA pydevd 2.8.0 pydevd_file_utils NA pydevd_plugins NA pydevd_tracing NA pygments 2.11.2 pyparsing 3.0.7 pyro 1.8.1 pytorch_lightning 1.5.10 pytz 2022.1 requests 2.27.1 rich NA ruamel NA simplejson 3.17.6 sinfo 0.3.4 sitecustomize NA six 1.16.0 sklearn 0.23.2 skmisc 0.1.4 storemagic NA tables 3.7.0 tensorboard 2.9.1 texttable 1.6.4 threadpoolctl 3.1.0 toolz 0.11.2 torch 1.12.0+cu102 torchdata 0.4.0 torchmetrics 0.9.2 torchtext 0.13.0 torchvision 0.13.0+cu102 tornado 6.1 tqdm 4.64.0 traitlets 5.3.0 tree 0.1.7 typing_extensions NA urllib3 1.26.9 wcwidth 0.2.5 wrapt 1.13.3 yaml 5.4.1 zmq 23.2.0 zope NA
----- IPython 7.31.1 jupyter_client 7.3.4 jupyter_core 4.10.0 jupyterlab 3.4.3 notebook 6.4.8 ----- Python 3.10.5 (main, Jun 8 2022, 09:26:22) [GCC 11.3.0] Linux-3.10.0-1160.62.1.el7.x86_64-x86_64-with-glibc2.35 ----- Session information updated at 2022-08-31 13:15